// Flat offset of channel `k` for the current (n, s) pair, for a buffer that is
// indexed as [num][channels][spatial_dim]. Expands using the local names `n`,
// `s`, `channels` and `spatial_dim` of the enclosing function. Every operand
// is parenthesized and widened to 64 bits so the offset cannot wrap when the
// buffer holds more than INT_MAX elements.
#define ARGMAX_IDX(k) \
  ((long long)(s) + (long long)(spatial_dim) * ((long long)(k) + (long long)(channels) * (long long)(n)))

// Per-thread argmax over the channel axis.
//
// Each thread handles one (n, s) pair: n is taken from the x dimension of the
// launch and s from the y dimension. The thread scans the `channels` values at
// offsets ARGMAX_IDX(0) .. ARGMAX_IDX(channels - 1) and stores the index of the
// largest one, converted to T, at output[s + spatial_dim * n]. The comparison
// is a strict `>`, so on ties the lowest channel index wins.
//
// Parameters:
//   input       - values to scan; read at offsets below num * channels * spatial_dim
//   output      - results; written at offsets below num * spatial_dim
//   num         - extent of the outer axis (x dimension of the launch)
//   channels    - extent of the axis being reduced
//   spatial_dim - extent of the inner axis (y dimension of the launch)
//
// Threads whose (n, s) falls outside num x spatial_dim return without touching
// memory, so the grid may overshoot the data.
template <typename T>
__device__ void argmax_forward(const T *input, T *output, int num, int channels, int spatial_dim) {
  // Widen before multiplying: the built-in index variables are unsigned 32-bit,
  // and narrowing their product to int could yield a negative index that slips
  // past the bounds check below.
  const long long n = (long long)blockIdx.x * blockDim.x + threadIdx.x;
  const long long s = (long long)blockIdx.y * blockDim.y + threadIdx.y;
  if (n >= num || s >= spatial_dim)
    return;

  int maxidx = 0;
  // With no channels there is nothing to read; index 0 is still stored below.
  if (channels > 0) {
    T maxval = input[ARGMAX_IDX(0)];
    for (int i = 1; i < channels; ++i) {
      // Load the candidate once and reuse it for the compare and the update.
      const T val = input[ARGMAX_IDX(i)];
      if (val > maxval) {
        maxval = val;
        maxidx = i;
      }
    }
  }

  // The winning channel index is stored in the element type of the output.
  output[s + (long long)spatial_dim * n] = static_cast<T>(maxidx);
}

// C-linkage kernel entry points, one per element type, so that each kernel has
// an unmangled symbol name. Both pass their arguments unchanged to the
// argmax_forward template; the x dimension of the launch is mapped to `num`
// and the y dimension to `spatial_dim` there.

// Single-precision entry point.
extern "C" __global__ void argmax_forward_float(float *input, float *output, int num, int channels, int spatial_dim) {
  argmax_forward<float>(input, output, num, channels, spatial_dim);
}

// Double-precision entry point.
extern "C" __global__ void argmax_forward_double(double *input, double *output, int num, int channels, int spatial_dim) {
  argmax_forward<double>(input, output, num, channels, spatial_dim);
}

// vim: ft=cuda
